from logging import getLogger
from pathlib import PurePath
from typing import Callable, Collection, Set

from common import OperatingSystem
from common.agent_events import AgentEventTag
from common.credentials import Credentials
from common.tags import (
    BRUTE_FORCE_T1110_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
    INGRESS_TOOL_TRANSFER_T1105_TAG,
    NETWORK_SHARE_DISCOVERY_T1135_TAG,
    REMOTE_SERVICES_T1021_TAG,
    SYSTEM_SERVICES_T1569_TAG,
)
from infection_monkey.exploit.tools import (
    IRemoteAccessClient,
    RemoteAuthenticationError,
    RemoteCommandExecutionError,
)
from infection_monkey.i_puppet import TargetHost

from .smb_remote_access_client import SMBRemoteAccessClient
from .wmi_client import WMIClient
from .wmi_options import WMIOptions

# Tags merged into the caller's tag set by WMIRemoteAccessClient.login().
LOGIN_TAGS = {
    REMOTE_SERVICES_T1021_TAG,
    BRUTE_FORCE_T1110_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
}

# NOTE(review): SHARE_DISCOVERY_TAGS and COPY_FILE_TAGS are not referenced by the
# code in this module (copy_file() delegates tagging to the SMB client). They may be
# imported elsewhere; verify before removing.
SHARE_DISCOVERY_TAGS = {NETWORK_SHARE_DISCOVERY_T1135_TAG}
COPY_FILE_TAGS = {INGRESS_TOOL_TRANSFER_T1105_TAG}

# Tags merged into the caller's tag set by WMIRemoteAccessClient.execute_agent().
EXECUTION_TAGS = {
    REMOTE_SERVICES_T1021_TAG,
    EXPLOITATION_OF_REMOTE_SERVICES_T1210_TAG,
    SYSTEM_SERVICES_T1569_TAG,
}

logger = getLogger(__name__)


class WMIRemoteAccessClient(IRemoteAccessClient):
    """
    Remote access client that pairs an SMB client with a WMI client.

    SMB is used for authentication, file copies and writable-path discovery;
    WMI is used for authentication and for launching the agent process.
    """

    def __init__(
        self,
        host: TargetHost,
        options: WMIOptions,
        command_builder: Callable[[PurePath], str],
        smb_client: SMBRemoteAccessClient,
        wmi_client: WMIClient,
    ):
        """
        :param host: Target host, passed to the WMI client on login
        :param options: WMI options (stored, but not read by this class's own methods)
        :param command_builder: Turns the agent binary's remote path into the command
                                line that is executed over WMI
        :param smb_client: Client used for SMB login, file copies and writable paths
        :param wmi_client: Client used for WMI login and remote process execution
        """
        self._host = host
        self._options = options
        self._build_command = command_builder
        self._smb_client = smb_client
        self._wmi_client = wmi_client

    def login(self, credentials: Credentials, tags: Set[AgentEventTag]):
        """
        Authenticate to the target over SMB first, then over WMI.

        :param credentials: Credentials to authenticate with
        :param tags: Mutable set of event tags; LOGIN_TAGS are merged into it before
                     any login is attempted, and it is also handed to the SMB login
        :raises RemoteAuthenticationError: If either login raises; the underlying
                                           exception is chained as the cause
        """
        tags.update(LOGIN_TAGS)

        try:
            smb_client = self._smb_client
            smb_client.login(credentials, tags)

            # NOTE: no SMB cleanup is attempted here if the WMI login fails.
            wmi_client = self._wmi_client
            wmi_client.login(self._host, credentials)
        except Exception as login_error:
            raise RemoteAuthenticationError from login_error

    def get_os(self) -> OperatingSystem:
        """Report the target OS; this client always assumes Windows."""
        return OperatingSystem.WINDOWS

    def execute_agent(self, agent_binary_path: PurePath, tags: Set[AgentEventTag]):
        """
        Launch the agent binary on the target through WMI.

        :param agent_binary_path: Remote path of the agent binary; its parent directory
                                  is passed to the WMI client alongside the command
                                  (presumably as the working directory; confirm)
        :param tags: Mutable set of event tags; EXECUTION_TAGS are merged into it once
                     the WMI connection check passes
        :raises RemoteCommandExecutionError: If the WMI client is not connected, or if
                                             building/running the command raises
        """
        wmi_connected = self._wmi_client.connected()
        if not wmi_connected:
            raise RemoteCommandExecutionError("WMI client is not connected")

        try:
            tags.update(EXECUTION_TAGS)
            remote_command = self._build_command(agent_binary_path)
            remote_directory = str(agent_binary_path.parent)
            self._wmi_client.execute_remote_process(remote_command, remote_directory)
        except Exception as execution_error:
            raise RemoteCommandExecutionError from execution_error

    def copy_file(self, file: bytes, destination_path: PurePath, tags: Set[AgentEventTag]):
        """
        Copy ``file`` to ``destination_path`` on the target using the SMB client.

        This method adds no tags of its own; ``tags`` is forwarded to the SMB client,
        which may add its own.
        """
        smb_client = self._smb_client
        smb_client.copy_file(file, destination_path, tags)

    def get_writable_paths(self) -> Collection[PurePath]:
        """Return the writable remote paths reported by the SMB client."""
        writable_paths = self._smb_client.get_writable_paths()
        return writable_paths
